import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator


# Load the training or validation subset of an image directory
def train_val_generator(data_dir, target_size, batch_size, class_mode=None, subset='training',
                        validation_split=0.2, shuffle=True, seed=None):
    """Create an iterator over the training or validation subset of ``data_dir``.

    Pixel values are rescaled by 1/255. Keras carves the images into a
    training and a validation subset according to ``validation_split``.
    Call this function once with ``subset='training'`` and once with
    ``subset='validation'``, passing the same ``validation_split`` both
    times, so that the two iterators partition the same set of files.

    Args:
        data_dir: Root directory handed to ``flow_from_directory``; presumably
            laid out with one sub-directory per class (TODO confirm against
            the dataset).
        target_size: ``(height, width)`` that every image is resized to.
        batch_size: Number of images per batch.
        class_mode: Label mode forwarded to ``flow_from_directory``. The
            default ``None`` yields image batches without labels; pass e.g.
            ``'categorical'`` or ``'binary'`` to get ``(images, labels)``.
        subset: ``'training'`` or ``'validation'``.
        validation_split: Fraction of the images reserved for the validation
            subset. Defaults to 0.2, the value previously hard-coded here.
        shuffle: Whether Keras shuffles the image order (Keras default: True).
        seed: Optional seed for that shuffling, for reproducible runs.

    Returns:
        The iterator returned by ``ImageDataGenerator.flow_from_directory``.
    """
    # validation_split must be configured on the ImageDataGenerator itself;
    # flow_from_directory only selects which side of the split to read.
    train_val_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=validation_split)
    return train_val_datagen.flow_from_directory(
        directory=data_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode=class_mode,
        subset=subset,
        shuffle=shuffle,
        seed=seed,
    )


# Load the test set
def test_generator(data_dir, target_size, batch_size, class_mode=None, shuffle=False):
    """Create an iterator over the test images in ``data_dir``.

    Pixel values are rescaled by 1/255. No validation split is applied.

    Args:
        data_dir: Root directory handed to ``flow_from_directory``; presumably
            laid out with one sub-directory per class (TODO confirm against
            the dataset).
        target_size: ``(height, width)`` that every image is resized to.
        batch_size: Number of images per batch.
        class_mode: Label mode forwarded to ``flow_from_directory``. The
            default ``None`` yields image batches without labels; pass e.g.
            ``'categorical'`` or ``'binary'`` to get ``(images, labels)``.
        shuffle: Defaults to ``False`` so that batches are produced in the
            same order as ``iterator.filenames`` / ``iterator.classes``. This
            ordering is required to line up ``model.predict`` output with the
            ground truth. Keras' own default (``True``) would scramble that
            alignment.

    Returns:
        The iterator returned by ``ImageDataGenerator.flow_from_directory``.
    """
    test_datagen = ImageDataGenerator(rescale=1. / 255)
    return test_datagen.flow_from_directory(
        directory=data_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode=class_mode,
        shuffle=shuffle,
    )